import difflib
import os
import re
import time
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from datasets.common import RAW_DATA_PATH
from datasets.frozen_embeddings.embeddings.nlp import embed_text_bert
from datasets.frozen_embeddings.loader import EmbeddingDataset, load_embedding_dataset
from experiments.experimental_pipeline import DEFAULT_CACHE, DEFAULT_RESULTS
from src.logistic_regression import LogisticRegression

# Configuration
# Dataset loaded by main(); also used as the subdirectory name under
# DEFAULT_RESULTS and DEFAULT_CACHE.
DATASET_NAME = EmbeddingDataset.IMDB
# Forwarded as ``max_samples`` to load_embedding_dataset.
MAX_SAMPLES = 10000
# Size of each debug selection in main(): the top-influence set and the random
# baseline set. NOTE: the output files are named "top50_test"/"random50_test"
# regardless of this value.
TOP_K = 100
# Default seed for main()'s numpy random Generator.
SEED = 42

def save_debug_sets(prefix: str, data: dict, imdb_texts: list, save_dir: Path):
    """Write one selection of debug samples to disk in two formats.

    ``<save_dir>/<prefix>.npz`` holds the numeric arrays, including the
    embedding vectors. ``<save_dir>/<prefix>.csv`` holds the same rows with the
    raw review text in place of the vectors.

    Args:
        prefix: Base file name (no extension) shared by both outputs.
        data: Mapping with keys ``original_indices``, ``test_indices``,
            ``labels``, ``vectors`` and ``losses``.
        imdb_texts: Raw review texts, indexed by the entries of
            ``data["original_indices"]``.
        save_dir: Existing directory that receives both files.
    """
    npz_file = save_dir / f"{prefix}.npz"
    csv_file = save_dir / f"{prefix}.csv"

    # Numeric archive: every array, including the embedding vectors.
    # Keyword order fixes the member order inside the .npz.
    archive_keys = ("original_indices", "test_indices", "labels", "vectors", "losses")
    np.savez(npz_file, **{key: data[key] for key in archive_keys})

    # Human-readable table: the vectors are swapped for the review text they encode.
    source_rows = data["original_indices"]
    table = pd.DataFrame(
        {
            "original_index": source_rows,
            "test_index": data["test_indices"],
            "label": data["labels"],
            "loss": data["losses"],
            "text": list(map(imdb_texts.__getitem__, source_rows)),
        }
    )
    table.to_csv(csv_file, index=False)

    for written in (npz_file, csv_file):
        print(f"Saved: {written}")


def _plot_label_distribution(train_y, test_y, out_path: Path) -> None:
    """Save overlaid train/test histograms of the 0/1 labels to *out_path*."""
    bins = np.arange(-0.5, 2, 1)  # edges -0.5, 0.5, 1.5 -> one bar per label 0 and 1
    plt.figure(figsize=(8, 4))
    plt.hist(train_y, bins=bins, alpha=0.6, label="Train Labels", edgecolor='black')
    plt.hist(test_y, bins=bins, alpha=0.6, label="Test Labels", edgecolor='black')
    plt.xticks([0, 1])
    plt.xlabel("Label")
    plt.ylabel("Count")
    plt.title("Label Distribution in Train and Test Sets")
    plt.legend()
    plt.savefig(out_path)
    plt.close()


def _select_debug_indices(influence_norms, test_preds, test_y, top_k, rng):
    """Pick the top-``top_k`` and a random baseline of correctly classified test positions.

    Args:
        influence_norms: Per-test-sample influence-vector norms.
        test_preds: Hard 0/1 predictions for the test set.
        test_y: True test labels.
        top_k: Number of samples in each selection.
        rng: ``numpy.random.Generator`` used for the random baseline.

    Returns:
        ``(top_idx, random_idx)``, both positions into the test arrays.
        ``top_idx`` holds the correctly classified samples with the largest
        influence norm, in descending order. ``random_idx`` is drawn without
        replacement from the *remaining* correctly classified samples. It holds
        at most ``top_k`` entries (fewer if not enough remain).
    """
    correct_indices = np.where(test_preds == test_y)[0]
    # argsort yields positions into ``correct_indices``, not test positions:
    # map them back before using them to index the test arrays.
    order = np.argsort(-influence_norms[correct_indices])
    ranked = correct_indices[order]
    top_idx = ranked[:top_k]
    # The random baseline excludes the top-influence set (the head of ``ranked``).
    remaining = ranked[top_k:]
    random_idx = rng.choice(remaining, size=min(top_k, len(remaining)), replace=False)
    return top_idx, random_idx


def _build_debug_set(idx, test_indices, test_X, test_y, test_losses) -> dict:
    """Gather the per-sample arrays ``save_debug_sets`` expects for test positions *idx*."""
    return {
        "original_indices": test_indices[idx],
        "test_indices": idx,
        "labels": test_y[idx],
        "vectors": test_X[idx],
        "losses": test_losses[idx],
    }


def _load_augmentation_embeddings(aug_csv_path: Path):
    """Read the augmentation CSV and BERT-embed every review.

    Returns:
        ``(aug_df, aug_embeds)``. ``aug_df`` is the cleaned DataFrame.
        ``aug_embeds`` maps each ``Transformation`` name to its embedding, with
        a trailing constant 1.0 column appended. If a transformation name
        appears more than once, only the last row is kept.
    """
    aug_df = pd.read_csv(aug_csv_path)
    # Drop any rows where 'Transformation' or 'Review' is missing
    aug_df = aug_df.dropna(subset=["Transformation", "Review"])
    # NOTE(review): presumably mirrors how line breaks appear in the raw IMDB
    # text ("<br />"), so the "Original" row re-embeds identically — confirm.
    aug_df["Review"] = aug_df["Review"].str.replace("\n", "<br />")
    texts = [str(t) if pd.notnull(t) else "" for t in aug_df["Review"].tolist()]
    embeddings = embed_text_bert(texts)
    # Constant feature column; the models here are fit with fit_intercept=False.
    embeddings = np.hstack([embeddings, np.ones((embeddings.shape[0], 1))])
    return aug_df, dict(zip(aug_df["Transformation"].tolist(), embeddings))


def _verify_original_embedding(original_embed, test_vec, test_idx, text_csv, text_dataset, results_path):
    """Check that the re-embedded "Original" review matches the cached test embedding.

    Prints difference statistics. Saves the two vectors and a comparison plot
    under *results_path*. Prints a preview of the text diff. Raises
    ``AssertionError`` (explicitly, so it also fires under ``python -O``) when
    the vectors differ beyond ``atol=1e-4``.
    """
    print("Verifying new embedding method matches original...")

    diff = original_embed - test_vec
    abs_diff = np.abs(diff)

    print("🛠 DEBUG INFO:")
    print(f"Original embed shape: {original_embed.shape}")
    print(f"Test[{test_idx}] embed shape: {test_vec.shape}")
    print(f"Max abs difference: {np.max(abs_diff)}")
    print(f"Mean abs difference: {np.mean(abs_diff)}")
    print(f"Number of differing dimensions (> 1e-6): {np.sum(abs_diff > 1e-6)}")

    # NOTE: the "test706" key name predates test_idx=620; it is left unchanged
    # in case anything already reads this file — rename if nothing does.
    np.savez(results_path / "embedding_debug.npz",
             original=original_embed,
             test706=test_vec,
             diff=diff)

    # The five dimensions with the largest disagreement.
    for i in np.argsort(-abs_diff)[:5]:
        print(f"Index {i}: original={original_embed[i]:.6f}, test={test_vec[i]:.6f}, diff={abs_diff[i]:.6f}")

    plt.figure(figsize=(12, 4), dpi=400)
    plt.plot(original_embed, label="Original (from CSV)")
    plt.plot(test_vec, label=f"Test[{test_idx}] (from dataset)", alpha=0.7)
    plt.legend()
    plt.title(f"Embedding Comparison: Original vs Test[{test_idx}]")
    plt.xlabel("Embedding Dimension")
    plt.ylabel("Value")
    plt.tight_layout()
    plt.savefig(results_path / "embedding_comparison_plot.png")
    plt.savefig(results_path / "embedding_comparison_plot.pdf")
    plt.close()

    diffs = list(difflib.unified_diff(
        text_csv.splitlines(),
        text_dataset.splitlines(),
        fromfile='CSV Original',
        tofile=f'Test[{test_idx}]',
        lineterm='',
    ))
    print("\n".join(diffs[:30]))  # Show a preview of diff

    print(f"Saved plot and embedding diff to: {results_path}")
    if not np.allclose(original_embed, test_vec, atol=1e-4):
        raise AssertionError(f"Original embedding does not match test sample at index {test_idx}")


def _analyze_augmentation_poisoning(model, aug_embeds, train_X, train_y, x_test, y_test, save_path):
    """Retrain once per augmentation, adding it as a label-flipped poison point.

    For each ``(name, x_poison)`` in *aug_embeds*:
      - append ``x_poison`` with label ``1 - y_test`` to the training set;
      - warm-start a fresh ``LogisticRegression`` from *model*;
      - compare the actual change in the test logit with the influence-function
        (IF) and rescaled-IF (RIF) estimates.

    Raw arrays go to ``save_path/<name>_analysis.npz``. Note that the npz
    ``actual_delta`` field stores the change in the model's prediction output,
    while the returned ``actual_delta_pred`` is the change in the logit
    ``w·x_test``.

    Returns:
        Dict mapping augmentation name -> ``{"actual_delta_pred", "if_estimate",
        "rif_estimate", "max_abs_if"}``.
    """
    influence_summaries = {}
    flipped_label = 1 - y_test  # assumes binary 0/1 labels

    print("Starting poisoned retraining loop...")
    for name, x_poison in aug_embeds.items():
        print(f"\nEvaluating augmentation: {name}")
        X_poisoned = np.vstack([train_X, x_poison[None, :]])
        y_poisoned = np.append(train_y, flipped_label)

        # Retrain, warm-started from the clean model's parameters.
        model_poisoned = LogisticRegression(
            X_poisoned, y_poisoned, regularization=1e-5, fit_intercept=False, reg_type="L2"
        )
        model_poisoned.model = deepcopy(model.model)
        model_poisoned.fit(warm_start=True, verbose=False)

        H_poison = model_poisoned.n * model_poisoned.compute_hessian()
        H_inv_poison = model_poisoned.compute_hessian_inv() / model_poisoned.n

        grads_train = model_poisoned.compute_gradients(X_poisoned, y_poisoned)
        grad_test = model_poisoned.compute_gradients(x_test[None, :], np.array([y_test]))[0]

        # IF estimate of the logit change from the poison point (last row).
        if_estimate = -x_test @ H_inv_poison @ grads_train[-1]

        # Rescaled IF: divide by (1 - g), with g computed for the poison point only.
        y_hat = model_poisoned.model.get_model_predictions(X_poisoned)
        beta = np.sqrt(y_hat * (1 - y_hat))
        g = (x_poison.T @ H_inv_poison @ x_poison) * (beta[-1] ** 2)
        print(f"{g=}")
        rif_estimate = if_estimate / (1 - g)

        pred_orig = model.model.get_model_predictions(x_test)
        pred_poison = model_poisoned.model.get_model_predictions(x_test)
        delta_pred = pred_poison - pred_orig
        delta_z = np.inner(model_poisoned.model.weights - model.model.weights, x_test)

        # Influence of every original (non-poison) training point on the test sample.
        influences_all = grad_test @ H_inv_poison @ grads_train[:-1].T

        influence_summaries[name] = {
            "actual_delta_pred": delta_z,
            "if_estimate": if_estimate,
            "rif_estimate": rif_estimate,
            "max_abs_if": np.max(np.abs(influences_all)),
        }

        # Sanitize name for file usage
        safe_name = re.sub(r"[^a-zA-Z0-9_\-]", "_", name)
        np.savez(
            save_path / f"{safe_name}_analysis.npz",
            hessian=H_poison,
            hessian_inv=H_inv_poison,
            x_poison=x_poison,
            grad_test=grad_test,
            if_estimate=if_estimate,
            rif=rif_estimate,
            actual_delta=delta_pred,
            influences_all=influences_all,
        )
    return influence_summaries


def _plot_poisoning_barplot(influence_summaries, save_path):
    """Save a grouped bar chart of actual vs IF vs RIF change per augmentation.

    All three quantities are negated before plotting. The fourth bar slot
    (``x + 1.5 * width``) is intentionally empty; a "Max |IF|" bar there is disabled.
    """
    labels = list(influence_summaries)
    summaries = list(influence_summaries.values())
    actual = [-s["actual_delta_pred"] for s in summaries]
    if_vals = [-s["if_estimate"] for s in summaries]
    rif_vals = [-s["rif_estimate"] for s in summaries]

    x = np.arange(len(labels))
    width = 0.2

    plt.figure(figsize=(12, 6), dpi=400)
    plt.bar(x - 1.5 * width, actual, width, label="Actual ΔPrediction")
    plt.bar(x - 0.5 * width, if_vals, width, label="IF Estimate")
    plt.bar(x + 0.5 * width, rif_vals, width, label="RIF Estimate")
    plt.xticks(x, labels, rotation=45, ha="right")
    plt.ylabel("Prediction Change")
    plt.title("Influence Analysis of Augmentations on Test Sample")
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path / "poisoning_barplot.png")
    plt.savefig(save_path / "poisoning_barplot.pdf")
    plt.close()


def _run_test_poisoning_sweep(model, train_X, train_y, test_X, test_y, test_indices,
                              imdb_texts, results_path, seed, k=250):
    """Poison with each of *k* random test points (label flipped) and compare IF/RIF to the actual logit change.

    Writes ``poison_prediction_change.csv`` and an actual-vs-estimate scatter
    plot (png and pdf) under ``results_path/test_poisoning_sweep``.
    """
    rng = np.random.default_rng(seed)
    save_path = results_path / "test_poisoning_sweep"
    save_path.mkdir(exist_ok=True)

    k = min(k, len(test_X))  # choice without replacement cannot exceed the population
    rand_test_indices = rng.choice(len(test_X), size=k, replace=False)
    records = []

    print(f"\nRunning poisoning analysis on {k} random test samples...")
    t0 = time.time()
    for count, idx in enumerate(rand_test_indices):
        print(f"[{count + 1}/{k}]\tTest index: {idx} \tElapsed time: {time.time() - t0:.1f}")

        x_test = test_X[idx]
        y_test_true = test_y[idx]
        X_poisoned = np.vstack([train_X, x_test[None, :]])
        y_poisoned = np.append(train_y, 1 - y_test_true)

        model_poisoned = LogisticRegression(
            X_poisoned, y_poisoned, regularization=1e-5,
            fit_intercept=False, reg_type="L2"
        )
        model_poisoned.model = deepcopy(model.model)
        model_poisoned.fit(warm_start=True, verbose=False)

        H_inv_poison = model_poisoned.compute_hessian_inv() / model_poisoned.n
        grads_train = model_poisoned.compute_gradients(X_poisoned, y_poisoned)

        if_estimate = -x_test @ H_inv_poison @ grads_train[-1]
        y_hat = model_poisoned.model.get_model_predictions(X_poisoned)
        beta = np.sqrt(y_hat * (1 - y_hat))
        g = (x_test.T @ H_inv_poison @ x_test) * (beta[-1] ** 2)
        rif_estimate = if_estimate / (1 - g)

        delta_z = np.inner(model_poisoned.model.weights - model.model.weights, x_test)
        original_idx = test_indices[idx]
        records.append({
            "test_idx": idx,
            "original_idx": original_idx,
            "actual_delta_pred": delta_z,
            "if_estimate": if_estimate,
            "rif_estimate": rif_estimate,
            "true_label": y_test_true,
            "text": imdb_texts[original_idx],
        })

    df_out = pd.DataFrame(records)
    csv_out = save_path / "poison_prediction_change.csv"
    df_out.to_csv(csv_out, index=False)
    print(f"\nSaved results to {csv_out}")

    plt.figure(figsize=(10, 6), dpi=400)
    plt.scatter(df_out["actual_delta_pred"], df_out["if_estimate"], alpha=0.6, label="IF Estimate", marker='o')
    plt.scatter(df_out["actual_delta_pred"], df_out["rif_estimate"], alpha=0.6, label="RIF Estimate", marker='x')
    plt.axhline(0, color='gray', linestyle='-', linewidth=1)
    plt.axvline(0, color='gray', linestyle='-', linewidth=1)
    # y = x reference
    plt.plot(df_out["actual_delta_pred"], df_out["actual_delta_pred"], color='black', linestyle='--', linewidth=1,
             label='y = x')
    plt.xlabel("Actual ΔPrediction (Poisoned vs Original)")
    plt.ylabel("Estimated ΔPrediction")
    plt.title("Actual vs Predicted Change in Prediction")
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path / "influence_vs_actual_scatter.png")
    plt.savefig(save_path / "influence_vs_actual_scatter.pdf")
    plt.close()

    print(f"Saved scatter plot to {save_path}")


def main(seed=SEED, *, run_poisoning_sweep=False):
    """Run the IMDB influence-function debugging pipeline.

    Steps:
      1. Train an L2-regularised logistic regression (no fitted intercept) on
         the frozen embeddings.
      2. Rank correctly classified test samples by the norm of their influence
         vector (test gradient times the inverse Hessian).
      3. Save the top-``TOP_K`` samples and a random baseline drawn from the
         remaining correct ones, plus the Hessian and its inverse.
      4. Embed the reviews in ``review_transformations.csv``. Add each as a
         label-flipped training point and measure the change in the prediction
         on test sample ``test_idx``.

    Args:
        seed: Seed for the random baseline selection (and the optional sweep).
        run_poisoning_sweep: If True, run the random-test-point poisoning sweep
            after saving the Hessian, then return without running the
            augmentation analysis.
    """
    print("Initializing directories...")
    results_path = DEFAULT_RESULTS / DATASET_NAME
    os.makedirs(results_path, exist_ok=True)
    debug_path = DEFAULT_CACHE / DATASET_NAME / "debug"
    os.makedirs(debug_path, exist_ok=True)

    print("Loading dataset...")
    experiment = load_embedding_dataset(DATASET_NAME, max_samples=MAX_SAMPLES)
    train_X, train_y = experiment.train.features, experiment.train.labels
    test_X, test_y = experiment.test.features, experiment.test.labels
    test_indices = experiment.test.original_indices
    print(f"Train shape: {train_X.shape}, Test shape: {test_X.shape}")

    _plot_label_distribution(train_y, test_y, debug_path / "label_distribution.png")

    print("Training model...")
    model = LogisticRegression(train_X, train_y, regularization=1e-5, fit_intercept=False, reg_type="L2")
    model.fit(verbose=True)

    print("Computing test losses...")
    test_losses = model.model.get_model_losses(test_X, test_y)

    print("Computing Hessian and inverse...")
    H = model.n * model.compute_hessian()
    H_inv = model.compute_hessian_inv() / model.n

    print("Computing test set gradients...")
    test_gradients = model.compute_gradients(test_X, test_y)

    print("Computing influence vector norms...")
    influence_vecs = test_gradients @ H_inv  # shape: (num_test, d)
    influence_norms = np.linalg.norm(influence_vecs, axis=1)

    # Hard 0/1 predictions, thresholded at 0.5.
    test_preds = (model.model.get_model_predictions(test_X) >= 0.5).astype(int)
    rng = np.random.default_rng(seed)
    top_test_idx, random_test_idx = _select_debug_indices(influence_norms, test_preds, test_y, TOP_K, rng)

    print("Loading raw IMDB text data...")
    imdb_texts = pd.read_csv(RAW_DATA_PATH / "imdb.csv")["review"].tolist()

    # NOTE: file names say "50" but each set holds up to TOP_K samples; the names
    # are left unchanged so existing output paths stay stable.
    save_debug_sets(
        "top50_test",
        _build_debug_set(top_test_idx, test_indices, test_X, test_y, test_losses),
        imdb_texts,
        results_path,
    )
    save_debug_sets(
        "random50_test",
        _build_debug_set(random_test_idx, test_indices, test_X, test_y, test_losses),
        imdb_texts,
        results_path,
    )

    hessian_save_path = results_path / "imdb_hessian_and_inverse.npz"
    np.savez(hessian_save_path, hessian=H, hessian_inv=H_inv)
    print(f"Done. Results saved to: {results_path}")

    if run_poisoning_sweep:
        _run_test_poisoning_sweep(model, train_X, train_y, test_X, test_y, test_indices,
                                  imdb_texts, results_path, seed)
        return

    # ---- Augmentation Analysis Begins ----
    print("Loading and embedding augmentations...")
    aug_df, aug_embeds = _load_augmentation_embeddings(results_path / "review_transformations.csv")

    # NOTE(review): presumably the CSV's transformations are rewrites of this
    # test review; the check below confirms its "Original" row embeds to the
    # same vector as the cached test sample.
    test_idx = 620
    x_test = test_X[test_idx]
    y_test = test_y[test_idx]
    text_csv = aug_df[aug_df["Transformation"] == "Original"]["Review"].values[0]
    _verify_original_embedding(
        original_embed=aug_embeds["Original"],
        test_vec=x_test,
        test_idx=test_idx,
        text_csv=text_csv,
        text_dataset=imdb_texts[test_indices[test_idx]],
        results_path=results_path,
    )

    save_path = results_path / "poison_analysis"
    save_path.mkdir(parents=True, exist_ok=True)
    influence_summaries = _analyze_augmentation_poisoning(
        model, aug_embeds, train_X, train_y, x_test, y_test, save_path
    )

    _plot_poisoning_barplot(influence_summaries, save_path)
    print(f"\nAnalysis complete. Results in: {save_path}")


if __name__ == "__main__":
    # Script entry point: run the full pipeline with the module-level default seed.
    main(seed=SEED)
